
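# Working models used by each NDE estimator (Y = outcome regression,
# M = mediator model, A = exposure/propensity model):
#   G-formula = {Y, M}, IPW = {M, A}, Mixed = {A, Y}, AIPW = {A, M, Y}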

# Estimate the natural direct effect NDE = E[Y(1, M(0))] - E[Y(0, M(0))]
# of a 0/1 exposure A on an outcome Y, through a binary mediator M.
#
# Arguments:
#   dat  data.frame with the exposure `A`, the mediator `M`, the regressors
#        named in `beta` and, normally, the outcome `Y`. Column 1 of `dat`
#        (and of every process_data() matrix) is paired with the first,
#        intercept, coefficient of each model - presumably an all-ones
#        column; confirm against the code that builds `dat`.
#   beta list of fitted models:
#          beta_y  outcome regression coefficients (linear predictor);
#          beta_m  logistic coefficients for p(M = 1 | A, C);
#          beta_a  logistic coefficients for p(A = 1 | C) in "Mixed" and
#                  "AIPW"; used directly as fitted propensities in "IPW".
#        Regressors are matched to columns of `dat` by coefficient name.
#   px   subject weights for the G-formula sums (e.g. rep(1 / n, n) -
#        TODO confirm with callers).
#   opt  list with `estimator` ("G-formula", "IPW", "Mixed" or "AIPW") and
#        logical `reparam` (TRUE is only implemented for the G-formula;
#        NULL/NA are treated as FALSE).
#
# Returns a single numeric estimate. Signals an error for an unknown
# estimator or for `reparam = TRUE` with an estimator other than the
# G-formula.
#
# Working models: G-formula = {Y, M}, IPW = {M, A}, Mixed = {A, Y},
# AIPW = {A, M, Y}.
compute_effect <- function(dat, beta, px, opt) {
  estimators <- c("G-formula", "IPW", "Mixed", "AIPW")
  estimator <- opt$estimator
  if (!is.character(estimator) || length(estimator) != 1L ||
      !(estimator %in% estimators)) {
    stop("`opt$estimator` must be one of: ",
         paste(estimators, collapse = ", "), call. = FALSE)
  }
  reparam <- isTRUE(as.logical(opt$reparam))
  if (reparam && estimator != "G-formula") {
    # Without this check `effect` was never assigned for these inputs.
    stop("`opt$reparam = TRUE` is only implemented for the G-formula ",
         "estimator", call. = FALSE)
  }

  beta_y <- beta$beta_y
  beta_m <- beta$beta_m
  beta_a <- beta$beta_a

  # Counterfactual copy of `dat` with A and M set to `a` and `m`;
  # presumably a numeric matrix with the columns of `dat` (see
  # process_data()).
  set_am <- function(a, m) process_data(dat, a = a, m = m)
  # Positions in `dat` of a model's regressors; the first coefficient is
  # the intercept and is paired with column 1.
  design_cols <- function(coef) c(1, match(names(coef)[-1], colnames(dat)))
  expit <- function(eta) 1 / (1 + exp(-eta))
  # Outcome regression E[Y | A = a, M = m, C], one value per subject.
  mu_y <- function(a, m) set_am(a, m)[, design_cols(beta_y)] %*% beta_y
  # p(M = 1 | A = a, C) from the logistic mediator model, per subject.
  p_m1 <- function(a, m = 1) {
    expit(set_am(a, m)[, design_cols(beta_m)] %*% beta_m)
  }
  # p(M_i | a = 0, c_i) / p(M_i | a = 1, c_i) at each observed M_i.
  mediator_ratio <- function(p_m1a0, p_m1a1) {
    ratio <- p_m1a0 / p_m1a1
    m0 <- dat$M == 0
    ratio[m0] <- (1 - p_m1a0[m0]) / (1 - p_m1a1[m0])
    ratio
  }
  # p(A = 1 | C) from the logistic exposure model.
  p_a1_model <- function() {
    expit(as.matrix(dat[, design_cols(beta_a)]) %*% beta_a)
  }
  # Observed outcome. Prefer dat$Y so that a resampled `dat` (e.g. a
  # bootstrap replicate) is paired with its own outcomes; otherwise fall
  # back to the free variable `Y`, which is what this code used before.
  observed_y <- function() if ("Y" %in% names(dat)) dat[["Y"]] else Y

  switch(estimator,
    # G-formula:
    # NDE = sum_{i, m} {E[Y | m, a = 1, c_i] - E[Y | m, a = 0, c_i]}
    #                  * p(m | a = 0, c_i) * px_i
    "G-formula" = {
      p_m1a0 <- p_m1(0, dat$M)
      p_m0a0 <- 1 - p_m1a0
      if (!reparam) {
        y_a0m0 <- mu_y(0, 0)
        y_a0m1 <- mu_y(0, 1)
        y_a1m0 <- mu_y(1, 0)
        y_a1m1 <- mu_y(1, 1)
      } else {
        # Reparametrised outcome model, beta_y = (beta_f, w0, wa):
        # E[Y | a, m, c] = f(a, m, c) - centre_a + w0 + a * wa, where
        # centre_a is the px-weighted sum over subjects of
        # sum_m f(a, m, c_i) * p(m | a = 0, c_i).
        n_coef <- length(beta_y)
        beta_f <- beta_y[seq_len(n_coef - 2)]
        w0 <- beta_y[n_coef - 1]
        wa <- beta_y[n_coef]
        idx_f <- match(names(beta_f), colnames(dat))
        f <- function(a, m) set_am(a, m)[, idx_f] %*% beta_f
        f_a1m1 <- f(1, 1)
        f_a1m0 <- f(1, 0)
        f_a0m1 <- f(0, 1)
        f_a0m0 <- f(0, 0)
        centre_a1 <- sum(px * (f_a1m1 * p_m1a0 + f_a1m0 * p_m0a0))
        centre_a0 <- sum(px * (f_a0m1 * p_m1a0 + f_a0m0 * p_m0a0))
        y_a1m1 <- f_a1m1 - centre_a1 + w0 + wa
        y_a1m0 <- f_a1m0 - centre_a1 + w0 + wa
        y_a0m1 <- f_a0m1 - centre_a0 + w0
        y_a0m0 <- f_a0m0 - centre_a0 + w0
      }
      sum(px * ((y_a1m0 - y_a0m0) * p_m0a0 + (y_a1m1 - y_a0m1) * p_m1a0))
    },

    # IPW:
    # NDE = mean_i{ I(A_i = 1) / p(A = 1 | c_i)
    #               * p(M_i | a = 0, c_i) / p(M_i | a = 1, c_i) * Y_i
    #             - I(A_i = 0) / p(A = 0 | c_i) * Y_i }
    IPW = {
      ia1 <- dat$A
      ia0 <- 1 - ia1
      # NOTE(review): beta_a is used as the fitted propensities
      # p(A = 1 | c_i) themselves (the model-based computation is disabled
      # here), unlike "Mixed" and "AIPW" - confirm this is intended.
      p_a1 <- beta_a
      p_a0 <- 1 - p_a1
      ratio_m <- mediator_ratio(p_m1(0), p_m1(1))
      y_obs <- observed_y()
      mean((ia1 / p_a1) * ratio_m * y_obs - (ia0 / p_a0) * y_obs)
    },

    # Mixed:
    # NDE = mean_i{ I(A_i = 0) / p(A = 0 | c_i)
    #               * (E[Y | a = 1, M_i, c_i] - E[Y | a = 0, M_i, c_i]) }
    # Weighting the unexposed by 1 / p(A = 0 | c) gives their observed M_i
    # the population distribution of M(0), so BOTH regressions carry the
    # weight (an unweighted second term averages over the observed M mix).
    Mixed = {
      ia0 <- 1 - dat$A
      p_a0 <- 1 - p_a1_model()
      y_a1m <- mu_y(1, dat$M)
      y_a0m <- mu_y(0, dat$M)
      mean((ia0 / p_a0) * (y_a1m - y_a0m))
    },

    # AIPW, with eta_a(c_i) = sum_m E[Y | a, m, c_i] * p(m | a = 0, c_i):
    # E[Y(1, M(0))] = mean_i{ I(A_i = 1) / p(A = 1 | c_i)
    #                   * p(M_i | 0, c_i) / p(M_i | 1, c_i)
    #                   * (Y_i - E[Y | a = 1, M_i, c_i])
    #                 + I(A_i = 0) / p(A = 0 | c_i)
    #                   * (E[Y | a = 1, M_i, c_i] - eta_1(c_i))
    #                 + eta_1(c_i) }
    # E[Y(0, M(0))] is analogous with a = 0 (the mediator ratio is 1).
    # NDE = E[Y(1, M(0))] - E[Y(0, M(0))]
    # The subject-specific eta (not its sample mean) is what keeps the
    # estimate consistent when only the exposure model is misspecified.
    AIPW = {
      ia1 <- dat$A
      ia0 <- 1 - ia1
      p_a1 <- p_a1_model()
      p_a0 <- 1 - p_a1
      p_m1a0 <- p_m1(0)
      p_m0a0 <- 1 - p_m1a0
      ratio_m <- mediator_ratio(p_m1a0, p_m1(1))
      y_a1m <- mu_y(1, dat$M)
      y_a0m <- mu_y(0, dat$M)
      eta_a1 <- mu_y(1, 0) * p_m0a0 + mu_y(1, 1) * p_m1a0
      eta_a0 <- mu_y(0, 0) * p_m0a0 + mu_y(0, 1) * p_m1a0
      y_obs <- observed_y()
      ey_a1m0 <- mean(
        (ia1 / p_a1) * ratio_m * (y_obs - y_a1m) +
          (ia0 / p_a0) * (y_a1m - eta_a1) + eta_a1
      )
      ey_a0m0 <- mean(
        (ia0 / p_a0) * (y_obs - y_a0m) +
          (ia0 / p_a0) * (y_a0m - eta_a0) + eta_a0
      )
      ey_a1m0 - ey_a0m0
    }
  )
}


